3951
5375
Was ist der einfachste Weg, um einen mit n-Werten gefüllten Formtensor (batch_size, height, width) in einen Tensor der Form (batch_size, n, height, width) umzuwandeln?
Ich habe unten eine Lösung erstellt, aber es scheint, dass es einfachere und schnellere Möglichkeiten gibt, dies zu tun
def batch_tensor_to_onehot (tnsr, classes):
tnsr = tnsr.unsqueeze (1)
res = []
für cls in range (Klassen):
res.append ((tnsr == cls) .long ())
return torch.cat (res, dim = 1) 
Sie können torch.nn.functional.one_hot verwenden.
Für Ihren Fall:
a = torch.nn.functional.one_hot (tnsr, num_classes = classes)
out = a.permute (0, 3, 1, 2)
|
Sie können auch Tensor.scatter_ verwenden, das .permute vermeidet, aber wahrscheinlich schwieriger zu verstehen ist als die von @Alpha vorgeschlagene einfache Methode.
def batch_tensor_to_onehot (tnsr, classes):
Ergebnis = torch.zeros (tnsr.shape [0], Klassen, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device)
result.scatter_ (1, tnsr.unsqueeze (1), 1)
Ergebnis zurückgeben
Benchmarking-Ergebnisse
Ich war neugierig und beschloss, die drei Ansätze zu vergleichen. Ich fand heraus, dass es keinen signifikanten relativen Unterschied zwischen den vorgeschlagenen Methoden in Bezug auf Chargengröße, Breite oder Höhe zu geben scheint. In erster Linie war die Anzahl der Klassen der entscheidende Faktor. Natürlich kann wie bei jedem Benchmark-Kilometerstand variieren.
Die Benchmarks wurden unter Verwendung von Zufallsindizes und unter Verwendung von Chargengröße, Höhe, Breite = 100 gesammelt. Jedes Experiment wurde 20 Mal wiederholt, wobei der Durchschnitt angegeben wurde. Das Experiment num_classes = 100 wird einmal ausgeführt, bevor ein Profil zum Aufwärmen erstellt wird.
Die CPU-Ergebnisse zeigen, dass die ursprüngliche Methode wahrscheinlich am besten für num_classes kleiner als 30 war, während für die GPU der Scatter_-Ansatz am schnellsten zu sein scheint.
Tests unter Ubuntu 18.04, NVIDIA 2060 Super, i7-9700K
Der für das Benchmarking verwendete Code ist unten angegeben:
Fackel importieren
aus tqdm importiere tqdm
Importzeit
importiere matplotlib.pyplot als plt
def batch_tensor_to_onehot_slavka (tnsr, classes):
tnsr = tnsr.unsqueeze (1)
res = []
für cls in range (Klassen):
res.append ((tnsr == cls) .long ())
return torch.cat (res, dim = 1)
def batch_tensor_to_onehot_alpha (tnsr, classes):
Ergebnis = torch.nn.functional.one_hot (tnsr, num_classes = classes)
return result.permute (0, 3, 1, 2)
def batch_tensor_to_onehot_jodag (tnsr, classes):
Ergebnis = torch.zeros (tnsr.shape [0], Klassen, * tnsr.shape [1:], dtype = torch.long, device = tnsr.device)
result.scatter_ (1, tnsr.unsqueeze (1), 1)
Ergebnis zurückgeben
def main ():
num_classes = [2, 10, 25, 50, 100]
Höhe = 100
Breite = 100
bs = [100] * 20
für d in ['cpu', 'cuda']:
times_slavka = []
times_alpha = []
times_jodag = []
Aufwärmen = Richtig
für c in tqdm ([num_classes [-1]] + num_classes, ncols = 0):
tslavka = 0
talpha = 0
tjodag = 0
für b in bs:
tnsr = torch.randint (c, (b, Höhe, Breite)). to (Gerät = d)
t0 = time.time ()
y = batch_tensor_to_onehot_slavka (tnsr, c)
torch.cuda.synchronize ()
tslavka + = time.time () - t0
wenn nicht Aufwärmen:
times_slavka.append (tslavka / len (bs))
für b in bs:
tnsr = torch.randint (c, (b, Höhe, Breite)). to (Gerät = d)
t0 = time.time ()
y = batch_tensor_to_onehot_alpha (tnsr, c)
torch.cuda.synchronize ()
talpha + = time.time () - t0
wenn nicht Aufwärmen:
times_alpha.append (talpha / len (bs))
für b in bs:
tnsr = torch.randint (c, (b, Höhe, Breite)). to (Gerät = d)
t0 = time.time ()
y = batch_tensor_to_onehot_jodag (tnsr, c)
torch.cuda.synchronize ()
tjodag + = time.time () - t0
wenn nicht Aufwärmen:
times_jodag.append (tjodag / len (bs))
Aufwärmen = Falsch
fig = plt.figure ()
ax = fig.subplots ()
ax.plot (num_classes, times_slavka, label = 'Slavka-cat')
ax.plot (num_classes, times_alpha, label = 'Alpha-one_hot')
ax.plot (num_classes, times_jodag, label = 'jodag-dispers_')
ax.set_xlabel ('num_classes')
ax.set_ylabel ('Zeit (en)')
ax.set_title (f '{d} Benchmark')
ax.legend ()
plt.savefig (f '{d} .png')
plt.show ()
if __name__ == "__main__":
Main()
|
Deine Antwort
StackExchange.ifUsing ("Editor", function () {
StackExchange.using ("externalEditor", function () {
StackExchange.using ("Snippets", function () {
StackExchange.snippets.init ();
});
});
}, "Code Ausschnitte");
StackExchange.ready (function () {
var channelOptions = {
Tags: "" .split (""),
id: "1"
};
initTagRenderer ("". split (""), "" .split (""), channelOptions);
StackExchange.using ("externalEditor", function () {
// Editor muss nach Snippets ausgelöst werden, wenn Snippets aktiviert sind
if (StackExchange.settings.snippets.snippetsEnabled) {
StackExchange.using ("Snippets", function () {
createEditor ();
});
}}
sonst {
createEditor ();
}}
});
Funktion createEditor () {
StackExchange.prepareEditor ({
useStacksEditor: false,
heartbeatType: 'Antwort',
autoActivateHeartbeat: false,
convertImagesToLinks: true,
noModals: wahr,
showLowRepImageUploadWarning: true,
Ruf zu PostImages: 10,
bindNavPrevention: true,
Postfix: "",
imageUploader: {
brandingHtml: "Powered by \ u003ca href =" https: //imgur.com/ "\ u003e \ u003csvg class =" svg-icon "width =" 50 "height =" 18 "viewBox = "0 0 50 18" fill = "none" xmlns = "http: //www.w3.org/2000/svg" \ u003e \ u003cpath d = "M46.1709 9.17788C46.1709 8.26454 46,2665 7,94324 47,1084 7.58816C47.4091 7,46349 47,7169 7,36433 48,0099 7.26993C48.9099 6,97997 49,672 6,73443 49,672 5.93063C49.672 5,22043 48,9832 4,61182 48,1414 4.61182C47.4335 4,61182 46,7256 4,91628 46,0943 5.50789C45.7307 4,9328 45,2525 4,66231 44,6595 4.66231C43.6264 4,66231 43,1481 5,28821 43.1481 6.59048V11.9512C43.1481 13.2535 43.6264 13.8962 44.6595 13.8962C45.6924 13.8962 46.1709 13.253546.1709 11.9512V9.17788Z \ "/ \ u003e \ u003cpath d =" M32.492 10.1419C32.492 12.6954 34.1182 14.0484 37.0451 14.0484C39.9723 14.0484 41.5985 12.6954 41.5985 10.1419V6.59049C82.21.21.21.21.31.6 38.5948 5.28821 38.5948 6.59049V9.60062C38.5948 10.8521 38.2696 11.5455 37.0451 11.5455C35.8209 11.5455 35.4954 10.8521 35.4954 9.60062V6.59049C35.4954 5.28821 35.0173 4.66232 34.003 32.9 Füllregel = "Evenodd" Clip-Regel = "Evenodd" d = "M25.6622 17.6335C27.8049 17.6335 29.3739 16.9402 30.2537 15.6379C30.8468 14.7755 30.9615 13.5579 30.9615 11.9512V6.59049C30.9615 5.28821 30.4833 29,4502 4.66231C28.9913 4,66231 28,4555 4,94978 28,1109 5.50789C27.499 4,86533 26,7335 4,56087 25,7005 4.56087C23.1369 4,56087 21,0134 6,57349 21,0134 9.27932C21.0134 11,9852 23,003 13,913 25,3754 13.913C26.5612 13,913 27,4607 13,4902 28,1109 12.6616C28.1109 12,7229 28,1161 12,7799 28,121 12,8346 C28. 1256 12.8854 28.1301 12.9342 28.1301 12.983C28.1301 14.4373 27.2502 15.2321 25.777 15.2321C24.8349 15.2321 24.1352 14.9821 23.5661 14.7787C23.176 14.6393 22.8472 14.5218 22.5437 14.521822.21.23.73. C24.1317 7.94324 24.9928 7.09766 26.1024 7.09766C27.2119 7.09766 28.0918 7.94324 28.0918 9.27932C28.0918 10.6321 27.2311 11.5116 26.1024 11.5116C24.9737 11.5116 24.1317 10.6491 24.1317 9.27932Z \ 8045 13.2535 17.2637 13.8962 18.2965 13.8962C19.3298 13.8962 19.8079 13.2535 19.8079 11.9512V8.12928C19.8079 5.82936 18.4879 4.62866 16.4027 4.62866C15.1594 4.62866 14.279 4.98375 13.3609 5.88013C12.653 4.63.6. 58314 4.9328 7.10506 4.66232 6.51203 4.66232C5.47873 4.66232 5.00066 5.28821 5.00066 6.59049V11.9512C5.00066 13.2535 5.47873 13.8962 6.51203 13.8962C7.54479 13.8962 8.0232 13 .2535 8.0232 11.9512V8.90741C8.0232 7.58817 8.44431 6.91179 9.53458 6.91179C10.5104 6.91179 10.893 7.58817 10.893 8.94108V11.9512C10.893 13.2535 11.3711 13.8962 12.4044 13.8962C13.437.13.9 C16.4027 6.91179 16.8045 7.58817 16.8045 8.94108V11.9512Z \ "/ \ u003e \ u003cpath d =" M3.31675 6.59049C3.31675 5.28821 2.83866 4.66232 1.82471 4.66232C0.791758 4.66232 0.313353 13.513 1,82471 13,8962C2,85798 13,8962 3,31675 13,2535 3,31675 11,9512V6,59049Z \ "/ \ u003e \ u003cpath d =" M1,87209 0,400291C0,843612 0,400291 0 1,1159 0 1,98861C0 2,87869 0,822846 3,57676 387 C3.7234 1.1159 2.90056 0.400291 1.87209 0.400291Z "fill =" # 1BB76E "/ \ u003e \ u003c / svg \ u003e \ u003c / a \ u003e",
contentPolicyHtml: "Benutzerbeiträge lizenziert unter \ u003ca href = \" https: //stackoverflow.com/help/licensing \ "\ u003ecc by-sa \ u003c / a \ u003e \ u003ca href = \" https://stackoverflow.com / legal / content-policy \ "\ u003e (Inhaltsrichtlinie) \ u003c / a \ u003e",
allowUrls: true
},
onDemand: wahr,
discardSelector: ".discard-answer"
, instantShowMarkdownHelp: true, enableTables: true, enableSnippets: true
});
}}
});
Vielen Dank für Ihre Antwort auf Stack Overflow!
Bitte beantworten Sie die Frage unbedingt. Geben Sie Details an und teilen Sie Ihre Forschung!
Aber vermeiden Sie ...
Um Hilfe bitten, Klarheit schaffen oder auf andere Antworten antworten.
Aussagen auf der Grundlage von Meinungen machen; Unterstützen Sie sie mit Referenzen oder persönlichen Erfahrungen.
Weitere Informationen finden Sie in unseren Tipps zum Schreiben großartiger Antworten.
Entwurf gespeichert
Entwurf verworfen
Anmelden oder anmelden
StackExchange.ready (function () {
StackExchange.helpers.onClickDraftSave ('# login-link');
});
Melden Sie sich mit Google an
Melde dich über Facebook an
Melden Sie sich mit E-Mail und Passwort an
einreichen
Post als Gast
Name
Email
Erforderlich, aber nie gezeigt
StackExchange.ready (
function () {
StackExchange.openid.initPostLogin ('. New-post-login', 'https% 3a% 2f% 2fstackoverflow.com% 2fquestions% 2f62245173% 2fpytorch-transform-tensor-to-one-hot% 23new-answer', 'question_page' );
}}
);
Post als Gast
Name
Email
Erforderlich, aber nie gezeigt
Veröffentlichen Sie Ihre Antwort
Verwerfen
Durch Klicken auf "Antwort posten" stimmen Sie unseren Nutzungsbedingungen, Datenschutzbestimmungen und Cookie-Richtlinien zu
Nicht die Antwort, die Sie suchen? Durchsuchen Sie andere Fragen mit dem Tag Python Pytorch Tensor One-Hot-Codierung oder stellen Sie Ihre eigene Frage.